from typing import Optional, Union

import copy
from itertools import product

import numpy as np
from ConfigSpace.configuration_space import Configuration, ConfigurationSpace
from ConfigSpace.forbidden import ForbiddenAndConjunction, ForbiddenEqualsClause
from sklearn.base import ClassifierMixin

from autosklearn.askl_typing import FEAT_TYPE_TYPE
from autosklearn.pipeline.base import BasePipeline
from autosklearn.pipeline.components.classification import ClassifierChoice
from autosklearn.pipeline.components.data_preprocessing import DataPreprocessorChoice
from autosklearn.pipeline.components.data_preprocessing.balancing.balancing import (
    Balancing,
)
from autosklearn.pipeline.components.feature_preprocessing import (
    FeaturePreprocessorChoice,
)
from autosklearn.pipeline.constants import SPARSE


class SimpleClassificationPipeline(BasePipeline, ClassifierMixin):
    """This class implements the classification task.

    It implements a pipeline, which includes one preprocessing step and one
    classification algorithm. It can render a search space including all known
    classification and preprocessing algorithms.

    Contrary to the sklearn API it is not possible to enumerate the
    possible parameters in the __init__ function because we only know the
    available classifiers at runtime. For this reason the user must
    specifiy the parameters by passing an instance of
    ConfigSpace.configuration_space.Configuration.

    Parameters
    ----------
    config : ConfigSpace.configuration_space.Configuration
        The configuration to evaluate.

    random_state : Optional[int | RandomState]
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance
        used by `np.random`.

    Attributes
    ----------
    _estimator : The underlying scikit-learn classification model. This
        variable is assigned after a call to the
        :meth:`autosklearn.pipeline.classification.SimpleClassificationPipeline
        .fit` method.

    _preprocessor : The underlying scikit-learn preprocessing algorithm. This
        variable is only assigned if a preprocessor is specified and
        after a call to the
        :meth:`autosklearn.pipeline.classification.SimpleClassificationPipeline
        .fit` method.

    See also
    --------

    References
    ----------

    Examples
    --------

    """

    def __init__(
        self,
        config: Optional[Configuration] = None,
        feat_type: Optional[FEAT_TYPE_TYPE] = None,
        steps=None,
        dataset_properties=None,
        include=None,
        exclude=None,
        random_state: Optional[Union[int, np.random.RandomState]] = None,
        init_params=None,
    ):
        self._output_dtype = np.int32
        if dataset_properties is None:
            dataset_properties = dict()
        if "target_type" not in dataset_properties:
            dataset_properties["target_type"] = "classification"
        super().__init__(
            feat_type=feat_type,
            config=config,
            steps=steps,
            dataset_properties=dataset_properties,
            include=include,
            exclude=exclude,
            random_state=random_state,
            init_params=init_params,
        )

    def fit_transformer(self, X, y, fit_params=None):

        if fit_params is None:
            fit_params = {}

        if self.config["balancing:strategy"] == "weighting":
            balancing = Balancing(strategy="weighting")
            _init_params, _fit_params = balancing.get_weights(
                y,
                self.config["classifier:__choice__"],
                self.config["feature_preprocessor:__choice__"],
                {},
                {},
            )
            _init_params.update(self.init_params)
            self.set_hyperparameters(
                feat_type=self.feat_type,
                configuration=self.config,
                init_params=_init_params,
            )

            if _fit_params is not None:
                fit_params.update(_fit_params)

        X, fit_params = super().fit_transformer(X, y, fit_params=fit_params)

        return X, fit_params

    def predict_proba(self, X, batch_size=None):
        """predict_proba.

        Parameters
        ----------
        X : array-like, shape = (n_samples, n_features)

        batch_size: int or None, defaults to None
            batch_size controls whether the pipeline will be
            called on small chunks of the data. Useful when calling the
            predict method on the whole array X results in a MemoryError.

        Returns
        -------
        array, shape=(n_samples,) if n_classes == 2 else (n_samples, n_classes)
        """
        if batch_size is None:
            return super().predict_proba(X)

        else:
            if not isinstance(batch_size, int):
                raise ValueError(
                    "Argument 'batch_size' must be of type int, "
                    "but is '%s'" % type(batch_size)
                )
            if batch_size <= 0:
                raise ValueError(
                    "Argument 'batch_size' must be positive, " "but is %d" % batch_size
                )

            else:
                # Probe for the target array dimensions
                target = self.predict_proba(X[0:2].copy())

                y = np.zeros((X.shape[0], target.shape[1]), dtype=np.float32)

                for k in range(max(1, int(np.ceil(float(X.shape[0]) / batch_size)))):
                    batch_from = k * batch_size
                    batch_to = min([(k + 1) * batch_size, X.shape[0]])
                    pred_prob = self.predict_proba(
                        X[batch_from:batch_to], batch_size=None
                    )
                    y[batch_from:batch_to] = pred_prob.astype(np.float32)

                return y

    def _get_hyperparameter_search_space(
        self,
        feat_type: Optional[FEAT_TYPE_TYPE] = None,
        include=None,
        exclude=None,
        dataset_properties=None,
    ):
        """Create the hyperparameter configuration space.

        Parameters
        ----------
        feat_type : dict, maps columns to there datatypes

        include : dict (optional, default=None)

        Returns
        -------
        cs : ConfigSpace.configuration_space.Configuration
            The configuration space describing the SimpleRegressionClassifier.
        """
        cs = ConfigurationSpace()

        if dataset_properties is None or not isinstance(dataset_properties, dict):
            dataset_properties = dict()
        if "target_type" not in dataset_properties:
            dataset_properties["target_type"] = "classification"
        if dataset_properties["target_type"] != "classification":
            dataset_properties["target_type"] = "classification"

        if "sparse" not in dataset_properties:
            # This dataset is probably dense
            dataset_properties["sparse"] = False

        cs = self._get_base_search_space(
            cs=cs,
            feat_type=feat_type,
            dataset_properties=dataset_properties,
            exclude=exclude,
            include=include,
            pipeline=self.steps,
        )

        classifiers = cs.get_hyperparameter("classifier:__choice__").choices
        preprocessors = cs.get_hyperparameter("feature_preprocessor:__choice__").choices
        available_classifiers = self._final_estimator.get_available_components(
            dataset_properties
        )

        possible_default_classifier = copy.copy(list(available_classifiers.keys()))
        default = cs.get_hyperparameter("classifier:__choice__").default_value
        del possible_default_classifier[possible_default_classifier.index(default)]

        # A classifier which can handle sparse data after the densifier is
        # forbidden for memory issues
        for key in classifiers:
            if SPARSE in available_classifiers[key].get_properties()["input"]:
                if "densifier" in preprocessors:
                    while True:
                        try:
                            forb_cls = ForbiddenEqualsClause(
                                cs.get_hyperparameter("classifier:__choice__"), key
                            )
                            forb_fpp = ForbiddenEqualsClause(
                                cs.get_hyperparameter(
                                    "feature_preprocessor:__choice__"
                                ),
                                "densifier",
                            )
                            cs.add_forbidden_clause(
                                ForbiddenAndConjunction(forb_cls, forb_fpp)
                            )
                            # Success
                            break
                        except ValueError:
                            # Change the default and try again
                            try:
                                default = possible_default_classifier.pop()
                            except IndexError:
                                raise ValueError(
                                    "Cannot find a legal default configuration."
                                )
                            cs.get_hyperparameter(
                                "classifier:__choice__"
                            ).default_value = default

        # which would take too long
        # Combinations of non-linear models with feature learning:
        classifiers_ = [
            "adaboost",
            "decision_tree",
            "extra_trees",
            "gradient_boosting",
            "k_nearest_neighbors",
            "libsvm_svc",
            "mlp",
            "random_forest",
            "gaussian_nb",
        ]
        feature_learning = [
            "kernel_pca",
            "kitchen_sinks",
            "nystroem_sampler",
        ]

        for c, f in product(classifiers_, feature_learning):
            if c not in classifiers:
                continue
            if f not in preprocessors:
                continue
            while True:
                try:
                    cs.add_forbidden_clause(
                        ForbiddenAndConjunction(
                            ForbiddenEqualsClause(
                                cs.get_hyperparameter("classifier:__choice__"), c
                            ),
                            ForbiddenEqualsClause(
                                cs.get_hyperparameter(
                                    "feature_preprocessor:__choice__"
                                ),
                                f,
                            ),
                        )
                    )
                    break
                except KeyError:
                    break
                except ValueError:
                    # Change the default and try again
                    try:
                        default = possible_default_classifier.pop()
                    except IndexError:
                        raise ValueError("Cannot find a legal default configuration.")
                    cs.get_hyperparameter(
                        "classifier:__choice__"
                    ).default_value = default

        # Won't work
        # Multinomial NB etc don't use with features learning, pca etc
        classifiers_ = ["multinomial_nb"]
        preproc_with_negative_X = [
            "kitchen_sinks",
            "pca",
            "truncatedSVD",
            "fast_ica",
            "kernel_pca",
            "nystroem_sampler",
        ]

        for c, f in product(classifiers_, preproc_with_negative_X):
            if c not in classifiers:
                continue
            if f not in preprocessors:
                continue
            while True:
                try:
                    cs.add_forbidden_clause(
                        ForbiddenAndConjunction(
                            ForbiddenEqualsClause(
                                cs.get_hyperparameter(
                                    "feature_preprocessor:__choice__"
                                ),
                                f,
                            ),
                            ForbiddenEqualsClause(
                                cs.get_hyperparameter("classifier:__choice__"), c
                            ),
                        )
                    )
                    break
                except KeyError:
                    break
                except ValueError:
                    # Change the default and try again
                    try:
                        default = possible_default_classifier.pop()
                    except IndexError:
                        raise ValueError("Cannot find a legal default configuration.")
                    cs.get_hyperparameter(
                        "classifier:__choice__"
                    ).default_value = default

        self.configuration_space = cs
        self.dataset_properties = dataset_properties
        return cs

    def _get_pipeline_steps(
        self, dataset_properties, feat_type: Optional[FEAT_TYPE_TYPE] = None
    ):
        steps = []

        default_dataset_properties = {"target_type": "classification"}
        if dataset_properties is not None and isinstance(dataset_properties, dict):
            default_dataset_properties.update(dataset_properties)

        steps.extend(
            [
                [
                    "data_preprocessor",
                    DataPreprocessorChoice(
                        feat_type=feat_type,
                        dataset_properties=default_dataset_properties,
                        random_state=self.random_state,
                    ),
                ],
                ["balancing", Balancing(random_state=self.random_state)],
                [
                    "feature_preprocessor",
                    FeaturePreprocessorChoice(
                        feat_type=feat_type,
                        dataset_properties=default_dataset_properties,
                        random_state=self.random_state,
                    ),
                ],
                [
                    "classifier",
                    ClassifierChoice(
                        feat_type=feat_type,
                        dataset_properties=default_dataset_properties,
                        random_state=self.random_state,
                    ),
                ],
            ]
        )

        return steps

    def _get_estimator_hyperparameter_name(self):
        return "classifier"
